/*
 * Copyright (C) 2004-2008 Jive Software, 2016-2024 Ignite Realtime Foundation. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jivesoftware.openfire.container;

import org.dom4j.Document;
import org.jivesoftware.admin.FlashMessageTag;
import org.jivesoftware.admin.PluginFilter;
import org.jivesoftware.openfire.XMPPServer;
import org.jivesoftware.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.servlet.*;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * The plugin servlet acts as a proxy for web requests (in the admin console)
 * to plugins. Since plugins can be dynamically loaded and live in a different place
 * than normal Openfire admin console files, it's not possible to have them
 * added to the normal Openfire admin console web app directory.
 * <p>
 * The servlet listens for requests in the form {@code /plugins/[pluginName]/[JSP File]}
 * (e.g. {@code /plugins/foo/example.jsp}). It also listens for non-JSP requests in
 * forms like {@code /plugins/[pluginName]/images/*.png|gif},
 * {@code /plugins/[pluginName]/scripts/*.js|css} or
 * {@code /plugins/[pluginName]/styles/*.css} (e.g.
 * {@code /plugins/foo/images/example.gif}), which are served as files from the
 * plugin's {@code web} directory.
 * </p>
 * <p>
 * JSP files must be compiled and available via the plugin's class loader. The mapping
 * between JSP name and servlet class files is defined in the plugin's web.xml file
 * (conventionally {@code [pluginName]/web/WEB-INF/web.xml}). Typically, this file is
 * auto-generated by the JSP compiler when packaging the plugin.
 * </p>
 *
 * @author Matt Tucker
 */
public class PluginServlet extends HttpServlet {

    private static final Logger Log = LoggerFactory.getLogger(PluginServlet.class);

    /**
     * Name of the session attribute, request attribute and request parameter that carry the CSRF token.
     */
    private static final String CSRF_ATTRIBUTE = "csrf";

    /**
     * When {@code false} (the default), 'other' (static file) requests are only served when the requested file
     * resolves to a location inside the requested plugin's own {@code web} directory (OF-1886).
     */
    public static final SystemProperty<Boolean> ALLOW_LOCAL_FILE_READING = SystemProperty.Builder.ofType( Boolean.class )
        .setKey( "plugins.servlet.allowLocalFileReading" )
        .setDynamic( true )
        .setDefaultValue( false )
        .build();

    /**
     * All registered plugin servlets, keyed by the lower-cased concatenation of plugin name and URL pattern (OF-1105).
     */
    private static final Map<String, GenericServlet> servlets = new ConcurrentHashMap<>();

    // Both are assigned during (plugin) initialization and read later by request-handling threads. They are volatile
    // so that those assignments are visible across threads.
    private static volatile PluginManager pluginManager;
    private static volatile ServletConfig servletConfig;

    /**
     * The URL prefix under which plugin resources are exposed.
     */
    private static final String PLUGINS_WEBROOT = "/plugins/";

    /**
     * Retains the servlet configuration, so that its servlet context can be used to construct the contexts handed
     * to servlets and filters of plugins.
     *
     * @param config the servlet configuration.
     * @throws ServletException if the superclass initialization fails.
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        servletConfig = config;
    }

    /**
     * Dispatches a request under {@code /plugins/} to a registered JSP servlet, another registered plugin servlet,
     * or a static file in the plugin's web directory.
     * <p>
     * For plugins that have CSRF protection enabled, non-GET requests without a valid CSRF token are redirected back
     * to the requested URI with an error flash message. Requests that cannot be attributed to a known plugin result
     * in a 404; unexpected failures result in a 500.
     * </p>
     *
     * @param request  the request object.
     * @param response the response object.
     */
    @Override
    public void service(HttpServletRequest request, HttpServletResponse response) {
        final String pathInfo = request.getPathInfo();
        if (pathInfo == null) {
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        try {
            final PluginMetadata pluginMetadata = getPluginMetadataFromPath(pathInfo);
            if (pluginMetadata == null) {
                // Without metadata it is unknown whether CSRF protection applies, so nothing is served.
                Log.trace("Unable to handle request. No plugin found for: {}", pathInfo);
                response.setStatus(HttpServletResponse.SC_NOT_FOUND);
                return;
            }

            if (pluginMetadata.isCsrfProtectionEnabled() && !passesCsrf(request)) {
                request.getSession().setAttribute(FlashMessageTag.ERROR_MESSAGE_KEY, LocaleUtils.getLocalizedString("global.csrf.failed"));
                response.sendRedirect(request.getRequestURI());
                return;
            }

            // Handle JSP requests.
            if (pathInfo.endsWith(".jsp")) {
                setCSRF(request);
                handleJSP(pathInfo, request, response);
                return;
            }

            // Handle servlet requests. The servlet is resolved only once, so that it cannot disappear (be
            // unregistered) between checking for it and using it.
            final GenericServlet servlet = getServlet(pathInfo);
            if (servlet != null) {
                setCSRF(request);
                handleServlet(pathInfo, servlet, request, response);
            } else {
                // Handle image/other requests.
                handleOtherRequest(pathInfo, response);
            }
        }
        catch (Exception e) {
            Log.error(e.getMessage(), e);
            response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        }
    }

    /**
     * Generates a fresh random CSRF token, stores it in the session and also exposes it as a request attribute.
     *
     * @param request the request being handled.
     */
    private static void setCSRF(final HttpServletRequest request) {
        final String csrf = StringUtils.randomString(32);
        request.getSession().setAttribute(CSRF_ATTRIBUTE, csrf);
        request.setAttribute(CSRF_ATTRIBUTE, csrf);
    }

    /**
     * Checks whether a request passes CSRF validation: GET requests always pass, any other request must carry a
     * {@code csrf} parameter equal to the token stored in the session.
     *
     * @param request the request being handled.
     * @return {@code true} if the request passes CSRF validation.
     */
    private static boolean passesCsrf(final HttpServletRequest request) {
        if ("GET".equals(request.getMethod())) {
            // No CSRF's for GET requests
            return true;
        }

        final String sessionCsrf = (String) request.getSession().getAttribute(CSRF_ATTRIBUTE);
        return sessionCsrf != null && sessionCsrf.equals(request.getParameter(CSRF_ATTRIBUTE));
    }

    /**
     * Looks up the metadata of the plugin named by the first segment of the path.
     *
     * @param pathInfo the extra path info, expected to be of the form {@code /[pluginName]/...}.
     * @return the plugin metadata, or {@code null} if the path names no plugin or the plugin is unknown.
     */
    private static PluginMetadata getPluginMetadataFromPath(final String pathInfo) {
        final String[] parts = pathInfo.split("/");
        if (parts.length < 2 || parts[1].isEmpty()) {
            return null;
        }
        return XMPPServer.getInstance().getPluginManager().getMetadata(parts[1]);
    }

    /**
     * Registers all servlets (including compiled JSP pages) and servlet filters that are declared in a plugin's
     * web.xml file. A servlet or filter that fails to load is logged and skipped; it does not prevent the others
     * from being registered.
     *
     * @param manager the plugin manager.
     * @param plugin the plugin.
     * @param webXML the web.xml file containing JSP page names to servlet class file
     *      mappings.
     */
    public static void registerServlets( PluginManager manager, final Plugin plugin, File webXML)
    {
        pluginManager = manager;

        if ( !webXML.exists() )
        {
            Log.error("Could not register plugin servlets, file " + webXML.getAbsolutePath() + " does not exist.");
            return;
        }

        final String pluginName = getPluginNameFromWebXml( webXML );
        try
        {
            final Document webXmlDoc = WebXmlUtils.asDocument( webXML );

            for ( final String servletName : WebXmlUtils.getServletNames( webXmlDoc ) )
            {
                loadServlet( manager, plugin, pluginName, webXmlDoc, servletName );
            }

            for ( final String filterName : WebXmlUtils.getFilterNames( webXmlDoc ) )
            {
                loadFilter( manager, plugin, pluginName, webXmlDoc, filterName );
            }
        }
        catch (Throwable e)
        {
            Log.error( "An unexpected problem occurred while attempting to register servlets for plugin '{}'.", pluginName, e);
        }
    }

    /**
     * Instantiates, initializes and registers (under all of its URL patterns) a single servlet declared in a
     * plugin's web.xml. Failures are logged rather than propagated.
     *
     * @param manager     the plugin manager used to load the servlet class.
     * @param plugin      the plugin that declares the servlet.
     * @param pluginName  the name of the plugin directory, used as the prefix of the registered paths.
     * @param webXmlDoc   the parsed web.xml document.
     * @param servletName the name of the servlet, as declared in web.xml.
     */
    private static void loadServlet( final PluginManager manager, final Plugin plugin, final String pluginName, final Document webXmlDoc, final String servletName )
    {
        Log.debug( "Loading servlet '{}' of plugin '{}'...", servletName, pluginName );

        final String className = WebXmlUtils.getServletClassName( webXmlDoc, servletName );
        if ( className == null || className.isEmpty() )
        {
            Log.warn( "Could not load servlet '{}' of plugin '{}'. web-xml does not define a class name for this servlet.", servletName, pluginName );
            return;
        }

        try {
            final Object instance = manager.loadClass( plugin, className ).getDeclaredConstructor().newInstance();
            if (!(instance instanceof GenericServlet)) {
                Log.warn("Could not load servlet '{}' of plugin '{}'. Its class ({}) is not an instance of javax.servlet.GenericServlet.", servletName, pluginName, className);
                return;
            }
            final GenericServlet servlet = (GenericServlet) instance;

            Log.debug("Initializing servlet '{}' of plugin '{}'...", servletName, pluginName);
            servlet.init(new ServletConfig() {
                @Override
                public String getServletName() {
                    return servletName;
                }

                @Override
                public ServletContext getServletContext() {
                    return new PluginServletContext( servletConfig.getServletContext(), pluginManager, plugin );
                }

                @Override
                public String getInitParameter( String s )
                {
                    return lookupInitParameter( WebXmlUtils.getServletInitParams( webXmlDoc, servletName ), s );
                }

                @Override
                public Enumeration<String> getInitParameterNames()
                {
                    return enumerateInitParameterNames( WebXmlUtils.getServletInitParams( webXmlDoc, servletName ) );
                }
            });

            Log.debug("Registering servlet '{}' of plugin '{}' URL patterns.", servletName, pluginName);
            for (final String urlPattern : WebXmlUtils.getServletUrlPatterns(webXmlDoc, servletName)) {
                final String path = servletKey(pluginName + urlPattern);
                servlets.put(path, servlet);
                Log.debug("Servlet '{}' registered on path: {}", servletName, path);
            }
            Log.debug("Servlet '{}' of plugin '{}' loaded successfully.", servletName, pluginName);
        } catch (final Exception e) {
            Log.warn("Exception attempting to load servlet '{}' ({}) of plugin '{}'", servletName, className, pluginName, e);
        }
    }

    /**
     * Instantiates, initializes and registers (under all of its URL patterns) a single servlet filter declared in
     * a plugin's web.xml. Failures are logged rather than propagated, so that one broken filter does not prevent
     * the remaining filters from being loaded.
     *
     * @param manager    the plugin manager used to load the filter class.
     * @param plugin     the plugin that declares the filter.
     * @param pluginName the name of the plugin directory (used for logging).
     * @param webXmlDoc  the parsed web.xml document.
     * @param filterName the name of the filter, as declared in web.xml.
     */
    private static void loadFilter( final PluginManager manager, final Plugin plugin, final String pluginName, final Document webXmlDoc, final String filterName )
    {
        Log.debug( "Loading filter '{}' of plugin '{}'...", filterName, pluginName );

        final String className = WebXmlUtils.getFilterClassName( webXmlDoc, filterName );
        if ( className == null || className.isEmpty() )
        {
            Log.warn( "Could not load filter '{}' of plugin '{}'. web-xml does not define a class name for this filter.", filterName, pluginName );
            return;
        }

        try {
            final Object instance = manager.loadClass( plugin, className ).getDeclaredConstructor().newInstance();
            if ( !(instance instanceof Filter) )
            {
                Log.warn( "Could not load filter '{}' of plugin '{}'. Its class ({}) is not an instance of javax.servlet.Filter.", filterName, pluginName, className );
                return;
            }
            final Filter filter = (Filter) instance;

            Log.debug( "Initializing filter '{}' of plugin '{}'...", filterName, pluginName );
            filter.init( new FilterConfig()
            {
                @Override
                public String getFilterName()
                {
                    return filterName;
                }

                @Override
                public ServletContext getServletContext()
                {
                    return new PluginServletContext( servletConfig.getServletContext(), pluginManager, plugin );
                }

                @Override
                public String getInitParameter( String s )
                {
                    return lookupInitParameter( WebXmlUtils.getFilterInitParams( webXmlDoc, filterName ), s );
                }

                @Override
                public Enumeration<String> getInitParameterNames()
                {
                    return enumerateInitParameterNames( WebXmlUtils.getFilterInitParams( webXmlDoc, filterName ) );
                }
            } );

            Log.debug( "Registering filter '{}' of plugin '{}' URL patterns.", filterName, pluginName );
            for ( final String urlPattern : WebXmlUtils.getFilterUrlPatterns( webXmlDoc, filterName ) )
            {
                PluginFilter.addPluginFilter( urlPattern, filter );
            }
            Log.debug( "Filter '{}' of plugin '{}' loaded successfully.", filterName, pluginName );
        } catch (final Exception e) {
            Log.warn( "Exception attempting to load filter '{}' ({}) of plugin '{}'", filterName, className, pluginName, e );
        }
    }

    /**
     * Returns the value of an init parameter from a parsed set of init parameters.
     *
     * @param params the init parameters (name to value).
     * @param name   the name of the parameter to look up.
     * @return the parameter value, or {@code null} if there is none.
     */
    private static String lookupInitParameter( final Map<String, String> params, final String name )
    {
        // The empty-check avoids calling get() on an empty map (some immutable maps reject a null key).
        return params.isEmpty() ? null : params.get( name );
    }

    /**
     * Returns the names of a parsed set of init parameters as an enumeration.
     *
     * @param params the init parameters (name to value).
     * @return an enumeration of the parameter names, empty if there are none.
     */
    private static Enumeration<String> enumerateInitParameterNames( final Map<String, String> params )
    {
        return params.isEmpty() ? Collections.<String>emptyEnumeration() : Collections.enumeration( params.keySet() );
    }

    /**
     * Unregisters and destroys all servlets and servlet filters that are declared in a plugin's web.xml file.
     * A failure to destroy one servlet or filter is logged and does not prevent the others from being unregistered.
     *
     * @param webXML the web.xml file containing JSP page names to servlet class file
     *               mappings.
     */
    public static void unregisterServlets(File webXML)
    {
        if ( !webXML.exists() )
        {
            Log.error("Could not unregister plugin servlets, file " + webXML.getAbsolutePath() + " does not exist.");
            return;
        }

        final String pluginName = getPluginNameFromWebXml( webXML );
        try
        {
            final Document webXmlDoc = WebXmlUtils.asDocument( webXML );

            // Un-register and destroy all servlets.
            for ( final String servletName : WebXmlUtils.getServletNames( webXmlDoc ) )
            {
                Log.debug( "Unregistering servlet '{}' of plugin '{}'", servletName, pluginName );
                // A set, as one servlet instance may be registered under several URL patterns.
                final Set<Servlet> toDestroy = new HashSet<>();
                for ( final String urlPattern : WebXmlUtils.getServletUrlPatterns( webXmlDoc, servletName ) )
                {
                    final GenericServlet servlet = servlets.remove( servletKey( pluginName + urlPattern ) );
                    if ( servlet != null )
                    {
                        toDestroy.add( servlet );
                    }
                }

                for ( final Servlet servlet : toDestroy )
                {
                    try
                    {
                        servlet.destroy();
                    }
                    catch ( final Exception e )
                    {
                        Log.warn( "Exception while destroying servlet '{}' of plugin '{}'", servletName, pluginName, e );
                    }
                }

                Log.debug( "Servlet '{}' of plugin '{}' unregistered and destroyed successfully.", servletName, pluginName );
            }

            // Un-register and destroy all servlet filters.
            for ( final String filterName : WebXmlUtils.getFilterNames( webXmlDoc ) )
            {
                Log.debug( "Unregistering filter '{}' of plugin '{}'", filterName, pluginName );
                final String className = WebXmlUtils.getFilterClassName( webXmlDoc, filterName );
                // A set, as one filter instance may be registered under several URL patterns.
                final Set<Filter> toDestroy = new HashSet<>();
                for ( final String urlPattern : WebXmlUtils.getFilterUrlPatterns( webXmlDoc, filterName ) )
                {
                    final Filter filter = PluginFilter.removePluginFilter( urlPattern, className );
                    if ( filter != null )
                    {
                        toDestroy.add( filter );
                    }
                }

                for ( final Filter filter : toDestroy )
                {
                    try
                    {
                        filter.destroy();
                    }
                    catch ( final Exception e )
                    {
                        Log.warn( "Exception while destroying filter '{}' of plugin '{}'", filterName, pluginName, e );
                    }
                }

                Log.debug( "Filter '{}' of plugin '{}' unregistered and destroyed successfully.", filterName, pluginName );
            }
        }
        catch (Throwable e) {
            Log.error( "An unexpected problem occurred while attempting to unregister servlets for plugin '{}'.", pluginName, e);
        }
    }

    /**
     * Derives the name of the plugin directory from the location of its web.xml file, which is expected to be
     * three directory levels below the plugin directory (conventionally {@code plugins/[pluginName]/web/WEB-INF/web.xml}).
     *
     * @param webXML the plugin's web.xml file.
     * @return the name of the plugin directory.
     */
    private static String getPluginNameFromWebXml( final File webXML )
    {
        return webXML.getParentFile().getParentFile().getParentFile().getName();
    }

    /**
     * Registers a live servlet for a plugin programmatically, does not
     * initialize the servlet.
     * 
     * @param pluginManager the plugin manager
     * @param plugin the owner of the servlet
     * @param servlet the servlet.
     * @param relativeUrl the relative url where the servlet should be bound
     * @return the effective url that can be used to initialize the servlet
     * @throws ServletException if the servlet is null
     */
    public static String registerServlet(PluginManager pluginManager,
            Plugin plugin, GenericServlet servlet, String relativeUrl)
            throws ServletException {

        final String pluginName = pluginManager.getPluginPath(plugin).getFileName().toString();
        PluginServlet.pluginManager = pluginManager;
        if (servlet == null) {
            throw new ServletException("Servlet is missing");
        }
        final String pluginServletUrl = pluginName + relativeUrl;
        servlets.put(servletKey(pluginServletUrl), servlet);
        return PLUGINS_WEBROOT + pluginServletUrl;
    }

    /**
     * Unregister a live servlet for a plugin programmatically. Does not call
     * the servlet destroy method.
     * 
     * @param plugin the owner of the servlet
     * @param url the relative url where servlet has been bound
     * @return the unregistered servlet, so that it can be destroyed (or {@code null} if none was bound)
     * @throws ServletException if the URL is missing
     */
    public static GenericServlet unregisterServlet(Plugin plugin, String url)
            throws ServletException {
        if (url == null) {
            throw new ServletException("Servlet URL is missing");
        }
        // Fall back to the server's plugin manager if no servlet was ever registered through this class.
        final PluginManager manager = pluginManager != null ? pluginManager : XMPPServer.getInstance().getPluginManager();
        final String pluginName = manager.getPluginPath(plugin).getFileName().toString();
        return servlets.remove(servletKey(pluginName + url));
    }

    /**
     * Converts a path into the key format of the servlet registry. Keys are lower-cased (OF-1105) using a
     * locale-independent conversion, so that registration and lookup always agree.
     *
     * @param path the path (plugin name followed by a URL or URL pattern).
     * @return the registry key.
     */
    private static String servletKey(final String path) {
        return path.toLowerCase(Locale.ROOT);
    }

    /**
     * Handles a request for a JSP page. It checks to see if a servlet is mapped
     * for the JSP URL. If one is found, request handling is passed to it. If no
     * servlet is found, a 404 error is returned.
     *
     * @param pathInfo the extra path info.
     * @param request  the request object.
     * @param response the response object.
     * @throws ServletException if a servlet exception occurs while handling the request.
     * @throws IOException      if an IOException occurs while handling the request.
     */
    private void handleJSP(String pathInfo, HttpServletRequest request,
                           HttpServletResponse response) throws ServletException, IOException {
        // Strip the starting "/" from the path to find the JSP URL.
        final String jspURL = pathInfo.substring(1);
        final GenericServlet servlet = servlets.get(servletKey(jspURL));
        if (servlet != null) {
            Log.trace("Handling JSP at: {}", jspURL);
            servlet.service(request, response);
        }
        else {
            Log.trace("Unable to handle JSP. No registration for: {}", jspURL);
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
        }
    }

    /**
     * Passes request handling to an already resolved plugin servlet.
     *
     * @param pathInfo the extra path info (used for logging).
     * @param servlet  the servlet that is mapped to the path.
     * @param request  the request object.
     * @param response the response object.
     * @throws ServletException if a servlet exception occurs while handling the request.
     * @throws IOException      if an IOException occurs while handling the request.
     */
    private void handleServlet(String pathInfo, GenericServlet servlet, HttpServletRequest request,
                               HttpServletResponse response) throws ServletException, IOException {
        Log.trace("Handling servlet at: {}", pathInfo);
        servlet.service(request, response);
    }

    /**
     * Returns the correct servlet with mapping checks.
     *
     * @param pathInfo the pathinfo to map to the servlet.
     * @return the mapped servlet, or null if no servlet was found.
     */
    private GenericServlet getServlet(String pathInfo) {
        // Strip the starting "/" and convert to the registry's key format.
        final String key = servletKey(pathInfo.substring(1));
        final GenericServlet servlet = getWildcardMappedObject(servlets, key);
        Log.trace("Found servlet {} for path {}", servlet != null ? servlet.getServletName() : "(none)", key);
        return servlet;
    }

    /**
     * Looks up a value by exact key first and, failing that, by treating each key as a pattern in which {@code *}
     * matches any sequence of characters (all other characters are literals). When several patterns match, the first
     * one encountered while iterating the map wins.
     *
     * Package protected to facilitate unit testing.
     *
     * @param mapping the mapping to search.
     * @param query   the value to look up.
     * @param <T>     the type of the mapped values.
     * @return the mapped value, or {@code null} if no key matches.
     */
    static <T> T getWildcardMappedObject(final Map<String, T> mapping, final String query) {
        final T exactMatch = mapping.get(query);
        if (exactMatch != null) {
            return exactMatch;
        }
        for (final Map.Entry<String, T> entry : mapping.entrySet()) {
            if (query.matches(wildcardToRegex(entry.getKey()))) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Turns a wildcard pattern into a regular expression, using all characters but the {@code *} as a literal.
     *
     * @param wildcardPattern the pattern, in which {@code *} matches any sequence of characters.
     * @return an equivalent regular expression.
     */
    private static String wildcardToRegex(final String wildcardPattern) {
        String regex = Arrays.stream(wildcardPattern.split("\\*")) // split in parts that do not have a wildcard in them
            .map(Pattern::quote) // each part should be used as a literal (not as a regex or partial regex)
            .collect(Collectors.joining(".*")); // join all literal parts with a regex representation on the wildcard.
        if (wildcardPattern.endsWith("*")) { // the 'split' will have removed any trailing wildcard characters. Correct for that.
            regex += ".*";
        }
        return regex;
    }

    /**
     * Handles a request for other web items (images, etc.) by streaming the corresponding file from the plugin's
     * {@code web} directory. Responds with 404 when the path is malformed or the file does not exist (or is not a
     * regular file), and with 403 when the path escapes the plugin's web directory (unless
     * {@link #ALLOW_LOCAL_FILE_READING} is enabled).
     *
     * @param pathInfo the extra path info.
     * @param response the response object.
     * @throws IOException if an IOException occurs while handling the request.
     */
    private void handleOtherRequest(String pathInfo, HttpServletResponse response) throws IOException {
        final String[] parts = pathInfo.split("/");
        // The request must name both a plugin and a resource within it.
        if (parts.length < 3 || parts[1].isEmpty()) {
            Log.trace("Unable to handle 'other' request (not enough path parts): {}", pathInfo);
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }
        final String pluginName = parts[1];

        // Everything after "/[pluginName]/" is the resource path, relative to the plugin's web directory.
        String contextPath = "";
        final int index = pathInfo.indexOf(pluginName);
        if (index != -1) {
            contextPath = pathInfo.substring(index + pluginName.length() + 1);
        }

        final Path pluginDirectory = JiveGlobals.getHomePath().resolve("plugins");
        final Path webDirectory = pluginDirectory.resolve(pluginName).resolve("web");
        final Path file = webDirectory.resolve(contextPath);

        if ( !ALLOW_LOCAL_FILE_READING.getValue() ) {
            // Only serve files that live in the requested plugin's web directory. This guards against accessing
            // files from the operating system, Openfire's own configuration, or other plugins (OF-1886). Paths are
            // made absolute *before* normalizing, so that leading ".." segments are resolved too.
            final Path absolutePluginDirectory = pluginDirectory.toAbsolutePath().normalize();
            final Path absoluteWebDirectory = webDirectory.toAbsolutePath().normalize();
            final Path absoluteLookup = file.toAbsolutePath().normalize();
            if ( !absoluteWebDirectory.startsWith( absolutePluginDirectory ) || !absoluteLookup.startsWith( absoluteWebDirectory ) )
            {
                Log.trace("Unable to handle 'other' request (forbidden path): {}", pathInfo);
                response.setStatus( HttpServletResponse.SC_FORBIDDEN );
                return;
            }
        }

        // Directories (and other non-regular files) cannot be streamed; treat them as absent.
        if (!Files.isRegularFile(file)) {
            Log.trace("Unable to handle 'other' request (not found): {}", pathInfo);
            response.setStatus(HttpServletResponse.SC_NOT_FOUND);
            return;
        }

        String contentType = getServletContext().getMimeType(pathInfo);
        if (contentType == null) {
            contentType = "text/plain";
        }
        response.setContentType(contentType);
        Log.trace("Handling 'other' request with a response content type of {}: {}", contentType, pathInfo);
        // Set the size of the file (long-valued, so that files larger than 2GB are not truncated).
        response.setContentLengthLong(Files.size(file));

        // Stream the resource to the user, without loading it into memory in its entirety.
        try (final BufferedInputStream in = new BufferedInputStream(Files.newInputStream(file));
             final ServletOutputStream out = response.getOutputStream())
        {
            // Use a 32K buffer.
            final byte[] buf = new byte[32*1024];
            int len;
            while ((len = in.read(buf)) != -1) {
                out.write(buf, 0, len);
            }
        }
    }
}
